package gsssh

import (
	"gitee.com/Sxiaobai/gs/v2/gstool"
	"io"
	"net"
	"time"
)

// Bridge holds the state of a local TCP port forward whose traffic is relayed
// to targetHostPort through the SSH client (h.client.Dial).
//
// NOTE(review): these fields are read and written from several goroutines
// (see createConn / transferCopy / CloseBridge) without any locking.
type Bridge struct {
	// localListener accepts client connections on 127.0.0.1 with an OS-assigned port.
	localListener  net.Listener
	// remoteConn is the connection dialed through the SSH client to targetHostPort.
	remoteConn     net.Conn
	// localConn is the most recently accepted connection from localListener.
	localConn      net.Conn
	// targetHostPort is the "host:port" address dialed through the SSH client.
	targetHostPort string
	// localHostPort is localListener.Addr().String(); RunBridge returns it to callers.
	localHostPort  string
}

// RunBridge starts a local port forward to targetHostPort over the SSH client.
//
// If the SSH session is not running yet, it first connects with
// ConnectAuthPassword. It then listens on 127.0.0.1 with an OS-assigned port
// and dials targetHostPort through the SSH client. The remote connection is
// dialed eagerly, before any local client connects. On success it returns the
// local "host:port" that clients should connect to; traffic sent there is
// relayed to targetHostPort.
//
// On failure it returns an empty string and the error. Any local listener
// opened by this call is closed again, so a failed call does not leak the port.
func (h *SshConfig) RunBridge(targetHostPort string) (string, error) {
	if !h.isRunning {
		if err := h.ConnectAuthPassword(); err != nil {
			return ``, err
		}
	}
	h.bridge.targetHostPort = targetHostPort

	// Listen on the local port.
	if err := h.startListenLocal(); err != nil {
		return ``, err
	}

	// Dial the target through the SSH connection.
	if err := h.createRemoteConn(); err != nil {
		// The listener was opened above; release it so the port is not leaked.
		_ = h.bridge.localListener.Close()
		return ``, err
	}

	// Wait for the first local client in the background and start relaying.
	go h.createConn()
	h.RunType = RunTypeBridge
	return h.bridge.localHostPort, nil
}

// startListenLocal opens a TCP listener on 127.0.0.1 with an OS-chosen port.
// It records the listener and its address in h.bridge. On failure the
// listener field is still overwritten with whatever net.Listen returned, and
// a wrapped error is returned.
func (h *SshConfig) startListenLocal() error {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	h.bridge.localListener = listener
	if err != nil {
		return gstool.Error("监听本地端口失败: %s", err.Error())
	}
	h.bridge.localHostPort = listener.Addr().String()
	return nil
}

// createRemoteConn dials h.bridge.targetHostPort over TCP through the SSH
// client and stores the result in h.bridge.remoteConn. The field is
// overwritten even when the dial fails. Failures are returned as a wrapped
// error.
func (h *SshConfig) createRemoteConn() error {
	conn, err := h.client.Dial("tcp", h.bridge.targetHostPort)
	h.bridge.remoteConn = conn
	if err == nil {
		return nil
	}
	return gstool.Error("连接目标端口失败: %s", err.Error())
}

// createLocalConn blocks on the local listener until a client connects and
// stores the accepted connection in h.bridge.localConn. The field is
// overwritten even on failure. An Accept error is logged via h.Errof and
// returned unchanged.
func (h *SshConfig) createLocalConn() error {
	conn, err := h.bridge.localListener.Accept()
	h.bridge.localConn = conn
	if err == nil {
		return nil
	}
	h.Errof(`接收链接到本地监听端口失败 %s`, err.Error())
	return err
}

// createConn waits for a client to connect to the local listener, then starts
// relaying data between that client and the remote connection.
// Note: nothing past Accept runs until a client actually connects and sends
// traffic (for example a ping); merely having the port open triggers no work.
func (h *SshConfig) createConn() {
	if err := h.createLocalConn(); err != nil {
		gstool.FmtPrintlnLogTime(`创建local conn失败 %s`, err.Error())
		return
	}
	// Relay traffic in both directions.
	h.transferCopy(true, true)
}

// transferCopy starts the relay goroutines for the current local/remote
// connection pair and returns immediately.
//
// NOTE(review): boolLTR is never consulted. The ioCopyLocalToRemote pump
// (which copies remote -> local) is always started. CloseBridge passes false
// here but, having closed the previous pair, still needs that pump restarted,
// so honouring the flag would change reconnect behaviour. Confirm before
// changing. boolRTL controls whether ioCopyRemoteToLocal (local -> remote)
// is started.
func (h *SshConfig) transferCopy(boolLTR, boolRTL bool) {
	go h.ioCopyLocalToRemote()
	if boolRTL {
		go h.ioCopyRemoteToLocal()
	}
}

// ioCopyLocalToRemote copies bytes from h.bridge.remoteConn into
// h.bridge.localConn until the source or destination fails. Despite its name,
// the data flows remote -> local. Panics (e.g. from nil connections) are
// recovered and logged; the copy result is intentionally discarded.
func (h *SshConfig) ioCopyLocalToRemote() {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		gstool.FmtPrintlnLogTime("ioCopyLocalToRemote warn:%v", r)
	}()
	dst, src := h.bridge.localConn, h.bridge.remoteConn
	_, _ = io.Copy(dst, src)
}

// ioCopyRemoteToLocal copies bytes from h.bridge.localConn into
// h.bridge.remoteConn. Despite its name, the data flows local -> remote.
//
// When the copy ends after moving at least one byte, CloseBridge is called to
// tear down this pair and wait for the next client. If nothing was copied,
// the bridge is left as is. Panics (e.g. from nil connections) are recovered
// and logged.
func (h *SshConfig) ioCopyRemoteToLocal() {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		gstool.FmtPrintlnLogTime("ioCopyRemoteToLocal warn:%v", r)
	}()
	copied, _ := io.Copy(h.bridge.remoteConn, h.bridge.localConn)
	if copied <= 0 {
		return
	}
	h.CloseBridge()
}

// CloseBridge closes the current local and remote connections and then
// re-arms the bridge for the next client:
//  1. It waits 2 seconds.
//  2. It blocks until a new local client is accepted.
//  3. It dials the target again through the SSH client.
//  4. It restarts relaying via transferCopy.
//
// Despite its name, it does not close the local listener. It is invoked from
// ioCopyRemoteToLocal, so each client session chains into the next one.
//
// If accepting the next client fails, no remote connection is dialed. If the
// remote dial fails, the accepted client connection is closed. In both cases
// the bridge stops relaying instead of starting copies on nil connections.
func (h *SshConfig) CloseBridge() {
	// Guard against nil conns, e.g. when called before any connection exists.
	if h.bridge.remoteConn != nil {
		_ = h.bridge.remoteConn.Close()
	}
	if h.bridge.localConn != nil {
		_ = h.bridge.localConn.Close()
	}
	// NOTE(review): fixed pause before re-accepting; the reason is not visible here.
	time.Sleep(time.Second * 2)

	if localErr := h.createLocalConn(); localErr != nil {
		gstool.FmtPrintlnLogTime(`local err %s`, localErr.Error())
		// No client to serve: dialing the remote would only leak a connection.
		return
	}
	if err := h.createRemoteConn(); err != nil {
		gstool.FmtPrintlnLogTime(`重连失败 %s`, err.Error())
		// Nothing to relay to; release the client we just accepted.
		_ = h.bridge.localConn.Close()
		return
	}
	h.transferCopy(false, true)
}
